import torch

# Minimal demo of the tensor shapes produced by torch.nn.RNN with batch_first=True.
# Sizes are named once so the RNN's input_size and the sample tensor's feature
# dimension cannot drift apart.
BATCH_SIZE = 8
SEQ_LEN = 5
INPUT_SIZE = 4
HIDDEN_SIZE = 2
NUM_LAYERS = 1
BIDIRECTIONAL = False

rnn = torch.nn.RNN(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    batch_first=True,
    bidirectional=BIDIRECTIONAL,
)

# With batch_first=True the input layout is (batch, seq, feature).
# Named `inputs` rather than `input` to avoid shadowing the builtin input().
inputs = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE)
output, hidden = rnn(inputs)

# output: (batch, seq, num_directions * hidden_size) -> torch.Size([8, 5, 2])
print(output.size())
# hidden (h_n) is NOT affected by batch_first:
# (num_layers * num_directions, batch, hidden_size) -> torch.Size([1, 8, 2])
print(hidden.size())






